import os
os.chdir("../")
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import random

from utils.pbmaze_config import env_config, multi_env_config
from phone_booth_collab_maze import PBCMaze
from multi_phone_booth_collab_maze import PBCMaze as MultiPBCMaze
from pbcmaze_belief_model import ReceiverBeliefModel, SenderBeliefModel

import seaborn as sns
import matplotlib.pyplot as plt

def set_seed(seed):
    """Seed both global random number generators used by this script.

    Args:
        seed: Value forwarded unchanged to ``random.seed`` and then to
            ``np.random.seed``.
    """
    # Python's generator is seeded first, NumPy's second.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
def _sender_mi_sums(env, debug_pair=None):
    """Sweep all sender/receiver position pairs of ``env`` and sum the MI term.

    Belief models are built on ``env`` (receiver first, then sender) and their
    ``index_pos_dict`` values are used as the candidate positions. For each
    sender position the environment is re-configured once per receiver
    position and element ``[1]`` of ``env.calculate_mi_reward(...)`` is
    accumulated (presumably the mutual-information term -- confirm against the
    environment implementation).

    Args:
        env: An already constructed and reset maze environment. It is mutated:
            ``agent0_loc`` is overwritten and ``load_env_config_obl`` is
            called for every pair.
        debug_pair: Optional ``(sender_pos, receiver_pos)``; when the sweep
            reaches that pair the saved environment config is printed.

    Returns:
        ``(sender_possible_pos, sums)`` where ``sender_possible_pos`` is the
        sender model's ``index_pos_dict`` and ``sums`` is a list of
        ``(sender_pos, mi_sum)`` in that dict's iteration order.
    """
    # Fresh prior lists on every call so each environment gets its own
    # objects, exactly as when the priors were re-created inline per run.
    receiver_pi_0 = [0.2] * 5
    sender_pi_0 = [1/7] * 7
    rb_model = ReceiverBeliefModel(receiver_pi_0, env)
    sb_model = SenderBeliefModel(sender_pi_0, env)
    sender_possible_pos = sb_model.index_pos_dict
    receiver_possible_pos = rb_model.index_pos_dict

    sums = []
    for s_v in sender_possible_pos.values():
        mi_sum = 0.0
        for r_v in receiver_possible_pos.values():
            # Set the environment config then compute MI
            env.agent0_loc = s_v
            env.load_env_config_obl(1, r_v, env.save_env_config())
            if debug_pair is not None and s_v == debug_pair[0] and r_v == debug_pair[1]:
                print(env.save_env_config())
            mi_sum += env.calculate_mi_reward(sender_pi_0)[1]
        sums.append((s_v, mi_sum))
    return sender_possible_pos, sums


def _unreachable_mask(shape, possible_pos):
    """Return a boolean mask that hides heatmap cells with no sender position.

    Cell ``[i, j]`` is masked (``True``) when ``(j, i - 1)`` is not among the
    values of ``possible_pos``; this mirrors the ``[pos[1] + 1, pos[0]]``
    indexing used when the heatmap is filled.

    Args:
        shape: ``(rows, cols)`` of the heatmap array.
        possible_pos: Mapping whose values are the reachable positions.
    """
    # Snapshot the values once instead of rebuilding the view for every cell.
    valid = list(possible_pos.values())
    rows, cols = shape
    return np.array(
        [[(j, i - 1) not in valid for j in range(cols)] for i in range(rows)],
        dtype=bool,
    )


def _draw_heatmap(values, cell_mask, ax, cbar_ax, vmin, vmax, title):
    """Draw one annotated heatmap panel without ticks and return its Axes.

    Values are clipped to ``[0, 100]`` before plotting. ``vmin``/``vmax`` are
    passed through as given, so both panels can share one colour scale and
    one colour bar axis.
    """
    panel = sns.heatmap(
        values.clip(0, 100),
        mask=cell_mask,
        linewidth=2.0,
        annot=True,
        ax=ax,
        square=True,
        annot_kws={"fontsize": 8},
        cbar=True,
        cbar_ax=cbar_ax,
        vmin=vmin,
        vmax=vmax,
    )
    panel.set(xticklabels=[], yticklabels=[])
    panel.set(xlabel=None)
    panel.tick_params(bottom=False, left=False)
    panel.set(title=title)
    return panel


f, axes = plt.subplots(1, 2, sharex=True, sharey=True)

# --- Single functional phone booth: one environment, seed 0 -----------------
set_seed(0)
env = PBCMaze(env_args=env_config)
env.reset()
sender_possible_pos, single_sums = _sender_mi_sums(env, debug_pair=((9, 0), (0, 0)))

# Three rows; a position (x, y) lands in row y + 1, column x.
sender_heatmap_values = np.zeros((3, env.lengths[0]))
for pos, mi_sum in single_sums:
    sender_heatmap_values[pos[1] + 1, pos[0]] = mi_sum

print(sender_heatmap_values)

mask = _unreachable_mask(sender_heatmap_values.shape, sender_possible_pos)

# --- Multiple functional phone booths: average over seeded runs -------------
num_runs = 1000
# Width is taken from the single-booth environment, as both panels share axes.
multi_sender_heatmap_values = np.zeros((3, env.lengths[0]))
for r in range(num_runs):
    set_seed(r)
    env = MultiPBCMaze(env_args=multi_env_config)
    env.reset()
    sender_possible_pos, run_sums = _sender_mi_sums(env)
    for pos, mi_sum in run_sums:
        multi_sender_heatmap_values[pos[1] + 1, pos[0]] += mi_sum

multi_sender_heatmap_values /= num_runs

# The mask is built from the positions reported by the last run.
# NOTE(review): this assumes the set of sender positions is the same for
# every seed -- confirm against MultiPBCMaze.
multi_mask = _unreachable_mask(multi_sender_heatmap_values.shape, sender_possible_pos)

# --- Plot both panels on a shared colour scale -------------------------------
max_value = max(np.max(sender_heatmap_values), np.max(multi_sender_heatmap_values))
min_value = min(np.min(sender_heatmap_values), np.min(multi_sender_heatmap_values))
print(max_value)
print(min_value)
print(sender_heatmap_values.shape)
print(multi_sender_heatmap_values.shape)

# Dedicated axes on the right edge of the figure for the shared colour bar.
cbar_ax = f.add_axes([.91, .3, .03, .4])
g1 = _draw_heatmap(
    sender_heatmap_values, mask, axes[0], cbar_ax,
    min_value, max_value, 'Single Functional Phone Booth',
)
g2 = _draw_heatmap(
    multi_sender_heatmap_values, multi_mask, axes[1], cbar_ax,
    min_value, max_value, 'Multiple Functional Phone Booths',
)

plt.show()
